{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "import json\n",
    "import cv2\n",
    "import numpy as np\n",
    "from PIL import Image\n",
    "\n",
    "from architecture import shufflenet"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Load label decoding"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Parse the two JSON lookup tables that are combined into `decoder` below.\n",
    "with open('data/integer_encoding.json', 'r') as encoding_file:\n",
    "    encoding = json.load(encoding_file)\n",
    "\n",
    "with open('data/wordnet_decoder.json', 'r') as decoder_file:\n",
    "    wordnet_decoder = json.load(decoder_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Chain the two mappings into one index-to-label table: each key of\n",
    "# `encoding` is looked up in `wordnet_decoder`, and the label found there is\n",
    "# stored under the value that `encoding` assigns to that key.\n",
    "decoder = {}\n",
    "for name, index in encoding.items():\n",
    "    decoder[index] = wordnet_decoder[name]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Load an image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "image_path = 'panda.jpg'\n",
    "\n",
    "# cv2.imread reports a missing or unreadable file by returning None instead\n",
    "# of raising, so fail here with a clear message rather than inside cvtColor.\n",
    "image = cv2.imread(image_path)\n",
    "if image is None:\n",
    "    raise IOError('Could not read image: ' + image_path)\n",
    "\n",
    "# OpenCV decodes images in BGR channel order; switch to RGB.\n",
    "image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)\n",
    "\n",
    "# Resize to the 224x224 input size used by the model placeholder. The third\n",
    "# positional parameter of cv2.resize is `dst`, not the interpolation mode, so\n",
    "# the flag must be passed by keyword to take effect.\n",
    "image = cv2.resize(image, (224, 224), interpolation=cv2.INTER_LINEAR)\n",
    "\n",
    "# Last expression of the cell: render the preprocessed image inline.\n",
    "Image.fromarray(image)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Load a trained model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Start from an empty graph so that re-running this cell does not stack a\n",
    "# second copy of the network on top of the first.\n",
    "tf.reset_default_graph()\n",
    "\n",
    "# A batch of 224x224 three-channel uint8 images, rescaled to floats in [0, 1].\n",
    "# tf.cast replaces the deprecated tf.to_float; the resulting tensor is the same.\n",
    "raw_images = tf.placeholder(tf.uint8, [None, 224, 224, 3])\n",
    "images = tf.cast(raw_images, tf.float32)/255.0\n",
    "\n",
    "# Build the network with is_training=False and turn logits into probabilities.\n",
    "logits = shufflenet(images, is_training=False, depth_multiplier='0.5')\n",
    "probabilities = tf.nn.softmax(logits, axis=1)\n",
    "\n",
    "# Restore through the ExponentialMovingAverage name map, so that variables\n",
    "# which have a moving average are loaded from their averaged checkpoint entry.\n",
    "ema = tf.train.ExponentialMovingAverage(decay=0.995)\n",
    "variables_to_restore = ema.variables_to_restore()\n",
    "saver = tf.train.Saver(variables_to_restore)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Predict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The placeholder is batched, so give the single image a leading batch axis.\n",
    "feed_dict = {raw_images: image[np.newaxis, ...]}\n",
    "\n",
    "with tf.Session() as sess:\n",
    "    # Load the checkpoint into the graph, then run one forward pass.\n",
    "    saver.restore(sess, 'run00/model.ckpt-1331064')\n",
    "    batch_probabilities = sess.run(probabilities, feed_dict=feed_dict)\n",
    "\n",
    "# Keep only the row that belongs to the one image in the batch.\n",
    "result = batch_probabilities[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Index of the highest-probability class, converted to a plain int before it\n",
    "# is used as a key into `decoder`.\n",
    "best_index = int(np.argmax(result))\n",
    "\n",
    "print('The most probable label is:')\n",
    "print(decoder[best_index])"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.15"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
